//! # Concat
//!
//! Concatenates a list of tensors into a single tensor along a specified axis.
//!
//! **ONNX Spec**: <https://onnx.ai/onnx/operators/onnx__Concat.html>
//!
//! ## Opset Versions
//! - **Opset 1-3**: Initial version
//! - **Opset 4-10**: Updated type support
//! - **Opset 11-12**: More type support
//! - **Opset 13+**: Current version with extended type support
use crate::ir::Argument;

use crate::ir::{ArgType, Node, NodeBuilder, TensorType};
use crate::processor::{
    InputPreferences, InputSpec, NodeProcessor, NodeSpec, OutputPreferences, OutputSpec,
    ProcessError,
};

/// Configuration for Concat operation
#[derive(Debug, Clone)]
pub struct ConcatConfig {
    /// Axis along which inputs are concatenated, already normalized to a
    /// non-negative index (negative ONNX axes are resolved against the first
    /// input's rank during config extraction).
    pub axis: usize,
}

/// Node representation for Concat operation
#[derive(Debug, Clone)]
pub struct ConcatNode {
    /// Node name from the source graph.
    pub name: String,
    /// Inputs to concatenate, in order (tensors and/or shapes).
    pub inputs: Vec<Argument>,
    /// Output arguments; Concat produces exactly one output.
    pub outputs: Vec<Argument>,
    /// Extracted configuration (normalized concat axis).
    pub config: ConcatConfig,
}

pub(crate) struct ConcatProcessor;

impl NodeProcessor for ConcatProcessor {
    type Config = ConcatConfig;

    fn spec(&self) -> NodeSpec {
        NodeSpec {
            min_opset: 4,
            max_opset: None,
            // ONNX Concat accepts 1..N inputs and produces exactly one output.
            inputs: InputSpec::AtLeast(1),
            outputs: OutputSpec::Exact(1),
        }
    }

    /// When at least one input is already a Shape, prefer every rank-1 tensor
    /// input to also become a Shape, so the whole concatenation can be
    /// performed in the Shape domain.
    fn input_preferences(
        &self,
        node: &NodeBuilder,
        _opset: usize,
    ) -> Result<Option<InputPreferences>, ProcessError> {
        use crate::processor::ArgPreference;

        if node.inputs.is_empty() {
            return Ok(None);
        }

        // No preference unless some input is a Shape.
        if !node.inputs.iter().any(|input| input.ty.is_shape()) {
            return Ok(None);
        }

        // Ask for each rank-1 tensor input to be converted to Shape.
        let mut prefs = InputPreferences::new();
        for input in &node.inputs {
            if matches!(&input.ty, ArgType::Tensor(t) if t.rank == 1) {
                prefs = prefs.add(&input.name, ArgPreference::Shape);
            }
        }

        Ok(Some(prefs))
    }

    /// Infers the output type:
    /// - all-tensor inputs: output mirrors the first input's dtype and rank
    ///   (static shape is dropped);
    /// - all-Shape inputs: output is a Shape whose rank is the sum of the
    ///   input ranks;
    /// - mixed Shape / rank-1 tensor inputs: output is a provisional Shape,
    ///   corrected after constant conversion.
    fn infer_types(
        &self,
        node: &mut NodeBuilder,
        opset: usize,
        _output_preferences: &OutputPreferences,
    ) -> Result<(), ProcessError> {
        // Validate the axis attribute and first-input type. Propagate failure
        // with `?` instead of panicking — the previous `.expect` turned a
        // recoverable extraction error (e.g. missing `axis`) into a crash.
        let _config = self.extract_config(node, opset)?;

        // For shapes, axis must be 0 (since they're 1D) - validation already done in extract_config

        // Classify inputs to pick the inference path below.
        let has_shape = node
            .inputs
            .iter()
            .any(|i| matches!(i.ty, ArgType::Shape(_)));
        let has_rank1_tensor = node
            .inputs
            .iter()
            .any(|i| matches!(&i.ty, ArgType::Tensor(t) if t.rank == 1));

        if !has_shape && !has_rank1_tensor {
            // Regular tensor case: all inputs must be tensors with one dtype.
            let first_dtype = match &node.inputs[0].ty {
                ArgType::Tensor(t) => t.dtype,
                _ => {
                    return Err(ProcessError::TypeMismatch {
                        expected: "Tensor".to_string(),
                        actual: format!("{:?}", node.inputs[0].ty),
                    });
                }
            };

            for (i, input) in node.inputs.iter().enumerate().skip(1) {
                match &input.ty {
                    ArgType::Tensor(t) => {
                        if t.dtype != first_dtype {
                            return Err(ProcessError::TypeMismatch {
                                expected: format!("Tensor with dtype {:?}", first_dtype),
                                actual: format!("Tensor with dtype {:?} at input {}", t.dtype, i),
                            });
                        }
                    }
                    _ => {
                        return Err(ProcessError::TypeMismatch {
                            expected: "Tensor".to_string(),
                            actual: format!("{:?} at input {}", input.ty, i),
                        });
                    }
                }
            }
        }

        if has_shape && has_rank1_tensor {
            // Mixed inputs that will be unified after constant conversion.
            // Sum Shape ranks; for rank-1 tensors use the constant's element
            // count when known, otherwise assume 1 (corrected later).
            let mut provisional_rank: usize = 0;

            for input in &node.inputs {
                match &input.ty {
                    ArgType::Shape(rank) => {
                        provisional_rank += rank;
                    }
                    ArgType::Tensor(t) if t.rank == 1 => {
                        let contribution = input.value().as_ref().map(|v| v.shape[0]).unwrap_or(1);
                        provisional_rank += contribution;
                    }
                    _ => {
                        return Err(ProcessError::TypeMismatch {
                            expected: "Shape or rank-1 Tensor".to_string(),
                            actual: format!("{:?}", input.ty),
                        });
                    }
                }
            }

            // Output as Shape type since we have Shape inputs.
            // The rank is provisional and will be corrected after constant conversion.
            node.outputs[0].ty = ArgType::Shape(provisional_rank);
            return Ok(());
        }

        // Get the first input type - it determines the output type.
        let first_input_type = &node.inputs[0].ty;

        match first_input_type {
            ArgType::Tensor(tensor) => {
                node.outputs[0].ty = ArgType::Tensor(TensorType {
                    dtype: tensor.dtype,
                    rank: tensor.rank,
                    static_shape: None,
                });
            }
            ArgType::Shape(_) => {
                // Concatenating shapes sums their ranks. `try_fold` both
                // short-circuits on the first non-Shape input and avoids the
                // intermediate Vec the old collect-then-sum version built.
                let total_rank =
                    node.inputs
                        .iter()
                        .try_fold(0usize, |acc, input| match &input.ty {
                            ArgType::Shape(rank) => Ok(acc + rank),
                            _ => Err(ProcessError::TypeMismatch {
                                expected: "Shape".to_string(),
                                actual: format!("{:?}", input.ty),
                            }),
                        })?;

                node.outputs[0].ty = ArgType::Shape(total_rank);
            }
            _ => {
                return Err(ProcessError::TypeMismatch {
                    expected: "Tensor or Shape".to_string(),
                    actual: format!("{:?}", first_input_type),
                });
            }
        }

        Ok(())
    }

    /// Extracts and normalizes the required `axis` attribute.
    ///
    /// Negative axes are counted from the end of the first input's rank,
    /// per the ONNX spec.
    ///
    /// # Errors
    /// - `MissingAttribute` if `axis` is absent.
    /// - `TypeMismatch` if the first input is neither a Tensor nor a Shape.
    fn extract_config(
        &self,
        node: &NodeBuilder,
        _opset: usize,
    ) -> Result<Self::Config, ProcessError> {
        // `axis` is a required attribute per the ONNX spec.
        let axis = node
            .attrs
            .iter()
            .find(|(key, _)| key.as_str() == "axis")
            .map(|(_, value)| value.clone().into_i64())
            .ok_or_else(|| ProcessError::MissingAttribute("axis".to_string()))?;

        // The first input's rank drives negative-axis normalization. The
        // `expect` is safe: spec() declares InputSpec::AtLeast(1), so an
        // empty input list cannot reach this point through normal processing.
        let first_input = node
            .inputs
            .first()
            .expect("Concat requires at least one input (enforced by NodeSpec)");
        let rank = match &first_input.ty {
            ArgType::Tensor(tensor) => tensor.rank as i64,
            ArgType::Shape(_) => 1, // Shapes are 1D
            _ => {
                return Err(ProcessError::TypeMismatch {
                    expected: "Tensor or Shape".to_string(),
                    actual: format!("{:?}", first_input.ty),
                });
            }
        };

        // If axis is negative, it is counted from the end.
        let normalized_axis = if axis < 0 { axis + rank } else { axis };

        // TODO: Validate normalized_axis is within [0, rank). The spec requires
        // axis in [-r, r-1]; out-of-range values currently pass through, and a
        // still-negative value would wrap in the `as usize` cast below.
        // TODO: Validate all non-concat dimensions match across inputs —
        // currently only dtype is checked for tensors (in infer_types).

        Ok(ConcatConfig {
            axis: normalized_axis as usize,
        })
    }

    /// Builds the final `Node::Concat` from a processed builder.
    ///
    /// # Panics
    /// Panics if config extraction fails; by this stage the node has already
    /// been validated, so a failure here indicates a processing bug.
    fn build_node(&self, builder: NodeBuilder, opset: usize) -> Node {
        let config = self
            .extract_config(&builder, opset)
            .expect("Config extraction failed");

        Node::Concat(ConcatNode {
            name: builder.name,
            inputs: builder.inputs,
            outputs: builder.outputs,
            config,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::NodeType;
    use crate::node::test_utils::TestNodeBuilder;

    /// Builds a Concat test node with `num_inputs` f32 tensor inputs of the
    /// given rank and the given `axis` attribute.
    fn create_test_node(axis: i64, input_rank: usize, num_inputs: usize) -> TestNodeBuilder {
        TestNodeBuilder::new(NodeType::Concat, "test_concat")
            .input_tensors_f32::<Vec<usize>>("data", num_inputs, input_rank, None)
            .output_tensor_f32("output", input_rank, None)
            .attr_int("axis", axis)
    }

    #[test]
    fn test_concat_config_basic() {
        let node = create_test_node(1, 3, 2).process(ConcatProcessor, 16);
        let processor = ConcatProcessor;
        let config = processor.extract_config(&node, 16).unwrap();
        assert_eq!(config.axis, 1);
    }

    #[test]
    fn test_concat_config_negative_axis() {
        let node = create_test_node(-2, 3, 2).process(ConcatProcessor, 16);
        let processor = ConcatProcessor;
        let config = processor.extract_config(&node, 16).unwrap();
        assert_eq!(config.axis, 1); // -2 + 3 = 1
    }

    #[test]
    fn test_concat_config_shape_input() {
        let node = TestNodeBuilder::new(NodeType::Concat, "test_concat_shape")
            .input_shape("shape1", 2)
            .input_shape("shape2", 3)
            .output_shape("output", 5)
            .attr_int("axis", 0) // Required attribute
            .process(ConcatProcessor, 16);

        let processor = ConcatProcessor;
        let config = processor.extract_config(&node, 16).unwrap();
        assert_eq!(config.axis, 0); // Shape concat uses axis 0
    }

    #[test]
    fn test_concat_config_missing_axis() {
        let node = TestNodeBuilder::new(NodeType::Concat, "test_concat")
            .input_tensor_f32("data1", 3, None)
            .input_tensor_f32("data2", 3, None)
            .output_tensor_f32("output", 3, None)
            .build();

        let processor = ConcatProcessor;
        let result = processor.extract_config(&node, 16);
        assert!(matches!(result, Err(ProcessError::MissingAttribute(_))));
    }

    #[test]
    fn test_concat_config_axis_out_of_bounds() {
        let node = TestNodeBuilder::new(NodeType::Concat, "test_concat")
            .input_tensor_f32("data1", 3, None)
            .input_tensor_f32("data2", 3, None)
            .output_tensor_f32("output", 3, None)
            .attr_int("axis", 3)
            .build();

        let processor = ConcatProcessor;
        let result = processor.extract_config(&node, 16);
        // axis == rank is outside the ONNX-spec range [-r, r-1], but range
        // validation is not yet implemented (see TODO in extract_config), so
        // extraction currently succeeds.
        assert!(result.is_ok());
    }

    #[test]
    fn test_concat_update_outputs_shape() {
        let node = TestNodeBuilder::new(NodeType::Concat, "test_concat_shape")
            .input_shape("shape1", 2)
            .input_shape("shape2", 3)
            .input_shape("shape3", 1)
            .output_shape("output", 0) // Will be updated
            .attr_int("axis", 0) // Required attribute
            .process(ConcatProcessor, 16);

        // Check that output is Shape with sum of input ranks
        match &node.outputs[0].ty {
            ArgType::Shape(rank) => assert_eq!(*rank, 6), // 2 + 3 + 1
            _ => panic!("Expected Shape output"),
        }
    }

    #[test]
    fn test_concat_config_shape_negative_axis() {
        let node = TestNodeBuilder::new(NodeType::Concat, "test_concat_shape")
            .input_shape("shape1", 2)
            .input_shape("shape2", 3)
            .output_shape("output", 5)
            .attr_int("axis", -1) // -1 should become 0 for 1D shapes
            .process(ConcatProcessor, 16);

        let processor = ConcatProcessor;
        let config = processor.extract_config(&node, 16).unwrap();
        assert_eq!(config.axis, 0); // -1 + 1 = 0
    }

    #[test]
    fn test_concat_config_shape_invalid_axis() {
        let node = TestNodeBuilder::new(NodeType::Concat, "test_concat_shape")
            .input_shape("shape1", 2)
            .input_shape("shape2", 3)
            .output_shape("output", 5)
            .attr_int("axis", 1)
            .build();

        let processor = ConcatProcessor;
        let result = processor.extract_config(&node, 16);
        // Shapes are rank-1, so axis 1 equals the rank and is out of spec;
        // extraction still succeeds because range validation is not yet
        // implemented (see TODO in extract_config).
        assert!(result.is_ok());
    }

    #[test]
    fn test_concat_mixed_inputs() {
        let mut node = TestNodeBuilder::new(NodeType::Concat, "test_concat_mixed")
            .input_shape("shape1", 2)
            .input_tensor_f32("tensor1", 3, None)
            .output_shape("output", 0)
            .attr_int("axis", 0)
            .build();

        let processor = ConcatProcessor;
        let prefs = OutputPreferences::new();
        let _config = processor.extract_config(&node, 16).unwrap();
        // A rank-3 tensor mixed with Shape inputs is not unifiable.
        let result = processor.infer_types(&mut node, 16, &prefs);
        assert!(matches!(result, Err(ProcessError::TypeMismatch { .. })));
    }
}
